import asyncio
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from utils.trajectory_logger import TrajectoryLogger
from policies.utils.transition import Transition

def rollout(env,
            policies,
            num_episodes=1,
            metrics=None,
            logdir='results/cooperative_perception_overtake/eval/test/',
            data_processing_function=None,
):
    """Run ``num_episodes`` multi-agent episodes and return per-agent returns.

    On every environment step each policy first observes its own observation
    (together with the reward / terminated / truncated values from the
    previous step and the shared ``info``), then all ``act`` calls are
    dispatched to a thread pool and awaited together. The resulting actions
    are logged, the environment is stepped, and qualifying transitions are
    handed to ``policy.store_transition``.

    Args:
        env: Environment exposing ``reset() -> (observations, info)``,
            ``step(actions) -> (obs, rewards, terminated, truncated, info)``
            where ``terminated`` and ``truncated`` contain an ``'__all__'``
            key, and a ``step_count`` attribute. Its ``logdir`` attribute is
            overwritten at the start of every episode.
        policies: Mapping of agent id to a policy object providing
            ``reset()``, ``observe(...)``, ``act()``, ``store_transition()``
            and the attributes ``step_count``, ``skip_frames`` and
            ``decision_frequency``.
        num_episodes: Number of episodes to run.
        metrics: Optional object whose ``update(info)`` is called once per
            episode with the ``info`` returned by the last ``env.step``.
        logdir: Directory passed to ``TrajectoryLogger``.
        data_processing_function: Optional callable invoked as
            ``data_processing_function(policies, env)`` after each episode.

    Returns:
        dict mapping each agent id in ``policies`` to a list holding the
        summed reward of every episode, in episode order.
    """
    agents_ids = list(policies)
    episode_returns = {agent_id: [] for agent_id in agents_ids}

    async def rollout_episode(executor):
        # Inside a coroutine the running loop is the one driving us; fetch it
        # once instead of once per agent per step.
        loop = asyncio.get_running_loop()

        for _ in range(num_episodes):
            logger = TrajectoryLogger(logdir, frequency=10)
            env.logdir = logger.get_logdir()
            # Values fed to observe() before the first env.step of the episode.
            rewards = {agent_id: 0 for agent_id in agents_ids}
            episode_rewards = {agent_id: 0 for agent_id in agents_ids}
            terminated = {agent_id: False for agent_id in agents_ids}
            truncated = {agent_id: False for agent_id in agents_ids}

            observations, info = env.reset()
            for policy in policies.values():
                policy.reset()

            # main loop
            done = False
            while not done:
                actions = {}
                # (agent_id, obs, future) for every agent that has a policy
                pending = []
                for agent_id, obs in observations.items():
                    if agent_id not in policies:
                        continue
                    policy = policies[agent_id]
                    policy.observe(
                        obs,
                        rewards[agent_id],
                        terminated[agent_id],
                        truncated[agent_id],
                        info,
                    )
                    # Run act() on a worker thread so the agents' decisions
                    # can overlap (at most two at a time with this pool).
                    future = loop.run_in_executor(executor, policy.act)
                    pending.append((agent_id, obs, future))

                # Wait for every agent's action; results keep submission order.
                results = await asyncio.gather(
                    *(future for _, _, future in pending)
                )

                # Collect the actions and log
                for (agent_id, obs, _), action in zip(pending, results):
                    actions[agent_id] = action
                    logger.log(
                        frame_id=env.step_count,
                        agent_id=agent_id,
                        info={'obs': obs, 'action': action})

                # step the environment
                obs_next, rewards, terminated, truncated, info = env.step(actions)
                done = terminated['__all__'] or truncated['__all__']

                # store transitions (pre-step observation, post-step reward)
                for agent_id, obs in observations.items():
                    if agent_id not in policies:
                        continue
                    policy = policies[agent_id]
                    # NOTE(review): with decision_frequency == 1 the modulo
                    # test below can never equal 1 - confirm this is intended.
                    if (policy.step_count > policy.skip_frames
                            and policy.step_count % policy.decision_frequency == 1):
                        transition = Transition(
                            agent_id,
                            env.step_count,
                            obs,
                            actions[agent_id],
                            rewards[agent_id],
                            done,
                            info,
                        )
                        policy.store_transition(transition)

                # update episodic reward for each agent
                for agent_id, reward in rewards.items():
                    if agent_id in episode_rewards:
                        episode_rewards[agent_id] += reward

                # update observations
                observations = obs_next

            if data_processing_function is not None:
                data_processing_function(policies, env)

            # update metrics
            if metrics is not None:
                metrics.update(info)
            for agent_id, reward in episode_rewards.items():
                episode_returns[agent_id].append(reward)

    # Drive the coroutine on a private loop: asyncio.get_event_loop() is
    # deprecated when no loop is current, and a private loop leaves the
    # process-wide loop state untouched.
    loop = asyncio.new_event_loop()
    try:
        # The context manager shuts the pool down even when an episode raises,
        # so worker threads are never leaked.
        with ThreadPoolExecutor(max_workers=2) as executor:
            loop.run_until_complete(rollout_episode(executor))
    finally:
        loop.close()
    return episode_returns

def process_learning_data(policies, env):
    """Annotate every stored transition with joint-trajectory context.

    Transitions from all policies that own a replay buffer are grouped by
    their time step ``t``. Each transition is then mutated in place and
    receives:

    * ``other_reactions`` - the ``command`` / ``message`` / ``reasoning``
      entries of the other agents' actions at the next recorded time step
      (an empty dict when there is no next step);
    * ``obs_next`` - the same agent's observation at the next recorded time
      step, assigned only when that agent has a transition there;
    * ``time_to_collision`` - ``max(0, collision_time - t) / env.frame_rate``
      when the feedback reports a collision with a known time, else ``None``;
    * ``feedback`` - the dict returned by
      ``env.evaluator.get_language_feedback()``;
    * ``vehicle_id`` - ``feedback['agent_vehicle_mapping'][agent_id]``.

    Args:
        policies: Mapping of agent id to a policy with a ``replay_buffer``
            attribute that is either ``None`` or has an iterable ``memory``.
        env: Environment providing ``evaluator.get_language_feedback()`` and,
            when a collision is reported, ``frame_rate``.

    Returns:
        None. The transitions are updated in place.
    """
    reaction_fields = ('command', 'message', 'reasoning')

    # joint[t][agent_id] -> transition recorded by that agent at step t
    joint = {}
    for agent_id, policy in policies.items():
        buffer = policy.replay_buffer
        if buffer is not None:
            for stored in buffer.memory:
                joint.setdefault(stored.t, {})[agent_id] = stored

    # Environment feedback, shared by every transition of the episode
    feedback = env.evaluator.get_language_feedback()
    collision_info = feedback.get('collision_info', {})
    collided = collision_info.get('collision_occurred', False)
    collision_time = collision_info.get('collision_time', None)

    # Walk consecutive recorded steps; the last one is paired with None
    ordered_steps = sorted(joint)
    following_steps = ordered_steps[1:] + [None]
    for t, t_next in zip(ordered_steps, following_steps):
        upcoming = joint.get(t_next, None)

        for agent_id, transition in joint[t].items():
            reactions = {}
            if upcoming is not None:
                for peer_id, peer_transition in upcoming.items():
                    if peer_id == agent_id:
                        # Own entry at the next step supplies the next observation
                        transition.obs_next = peer_transition.obs
                    else:
                        # What the other agent did right after this step
                        peer_action = peer_transition.action
                        reactions[peer_id] = {
                            field: peer_action.get(field, None)
                            for field in reaction_fields
                        }
            transition.other_reactions = reactions

            # Remaining time until the collision, in seconds
            if collided and collision_time is not None:
                frames_left = max(0, collision_time - t)
                transition.time_to_collision = frames_left / env.frame_rate
            else:
                transition.time_to_collision = None

            # Attach the environment feedback and the agent's vehicle
            transition.feedback = feedback
            transition.vehicle_id = feedback['agent_vehicle_mapping'][agent_id]